% Forward Backward proof-of-concept code
% Coded by Qun Feng Tan
% ---- Workspace reset, noise parameter, and source image ----
close all
clear   % plain 'clear' resets variables; 'clear all' would also drop breakpoints and cached functions
clc
p = 0.25;   % probability that a bit is flipped when the noisy observation is generated further down
pic = imread('einstein_bit.jpg');
% A JPEG can decode to an M-by-N-by-3 array, which the 2-D transpose below
% rejects. Collapse the channels to a single intensity plane in that case;
% a 2-D image is left untouched.
if ndims(pic) > 2
    pic = mean(pic, 3);
end
height = size(pic,1);
width = size(pic,2);
% Transpose before flattening so that the column-major (:) walks the image
% row by row.
pic_tran = pic';
pic_vec = pic_tran(:);
pic_mean = mean(pic_vec);
% Binarise: a pixel becomes 1 when it is brighter than the image mean.
B = pic_vec>pic_mean;
img_size = numel(pic);
%-----calculating empirical distributions-------%
% All tables are estimated from the clean bit sequence B. A window of bits is
% packed oldest-bit-first into a 1-based bin index, e.g. a pair (a,b) maps to
% 2*a + b + 1 and a triple (a,b,c) to 4*a + 2*b + c + 1.

% Marginal: [P(x=0) P(x=1)].
p_x = [1 - sum(B)/img_size sum(B)/img_size];

sym = double(B(:));   % numeric column copy so bits can be weighted into indices
% Packed index of every 2-bit window (ending at i = 2..N) and every 3-bit
% window (ending at i = 3..N).
pair_idx = sym(1:end-1)*2 + sym(2:end) + 1;
trip_idx = sym(1:end-2)*4 + sym(2:end-1)*2 + sym(3:end) + 1;

% Joint tables. They are normalised by the number of windows actually counted
% (N-1 pairs, N-2 triples) so that each table sums to 1; dividing by img_size
% would leave them slightly short. max(.,1) avoids 0/0 when no window exists.
bins = accumarray(pair_idx, 1, [4 1])';
p_2x = bins/max(sum(bins),1);
bins = accumarray(trip_idx, 1, [8 1])';
p_3x = bins/max(sum(bins),1);

% Conditional tables: row = packed context (the 1, 2 or 3 preceding bits),
% column = value of the current bit + 1. Each row is divided by its own total.
bins = accumarray([sym(1:end-1) + 1, sym(2:end) + 1], 1, [2 2]);
p_x_1 = bsxfun(@rdivide, bins, sum(bins,2));
bins = accumarray([pair_idx(1:end-1), sym(3:end) + 1], 1, [4 2]);
p_x_2 = bsxfun(@rdivide, bins, sum(bins,2));
bins = accumarray([trip_idx(1:end-1), sym(4:end) + 1], 1, [8 2]);
p_x_3 = bsxfun(@rdivide, bins, sum(bins,2));
% Noise corruption: a bit of B is flipped wherever the uniform draw is below p.
Y = xor(B,(rand(img_size,1)<p));

% Observation tables counted from the (clean, noisy) pairs. Row = packed index
% of the last 1, 2 or 3 clean bits (oldest bit first), column = Y(i) + 1.
% Each row is divided by its own total.
src = double(B(:));
obs = double(Y(:)) + 1;

bins = accumarray([src + 1, obs], 1, [2 2]);
p_y_1 = bsxfun(@rdivide, bins, sum(bins,2));

bins = accumarray([src(1:end-1)*2 + src(2:end) + 1, obs(2:end)], 1, [4 2]);
p_y_2 = bsxfun(@rdivide, bins, sum(bins,2));

bins = accumarray([src(1:end-2)*4 + src(2:end-1)*2 + src(3:end) + 1, obs(3:end)], 1, [8 2]);
p_y_3 = bsxfun(@rdivide, bins, sum(bins,2));
%----1st iteration forward-----%
% aaa(1,:,t) holds the normalised weights of bit t being 0 or 1 after the
% observations Y(1..t) have been folded in one at a time.
aaa = zeros(1,2,img_size);
bbb = p_x;                                  % weights before any observation is seen
for t = 1:img_size
    weighted = bbb .* p_y_1(:,Y(t) + 1)';   % scale by the table column for the observed bit
    aaa(:,:,t) = weighted ./ sum(weighted);
    bbb = aaa(:,:,t) * p_x_1;               % carry the weights over to the next bit
end
forward1 = aaa(1,2,:) > 0.5;                % decide 1 where the weight of "1" exceeds one half
%----1st iteration backwards-----%
% Walk from the last bit to the first. ccc(:,:,t) is rebuilt from
% ccc(:,:,t+1) and the forward weights aaa(:,:,t); the last slice is taken
% unchanged from aaa.
ccc = aaa;
for t = img_size-1:-1:1
    to_zero = aaa(:,:,t) .* p_x_1(:,1)';    % forward weights times the column for next bit = 0
    to_one  = aaa(:,:,t) .* p_x_1(:,2)';    % forward weights times the column for next bit = 1
    ccc(:,:,t) = ccc(:,:,t+1) * [to_zero/sum(to_zero); to_one/sum(to_one)];
end
backward1 = ccc(1,2,:) > 0.5;
%----2nd iteration forward-----%
% State k = 2*x(i-1) + x(i) + 1 packs the previous and the current bit, which
% is the row order used by p_y_2 and p_x_2.
aaa = zeros(1,4,img_size);
bbb = p_2x;   % starting weights over the four states
% The recursion starts at i = 1, treating the bit before the first pixel as a
% virtual predecessor weighted by p_2x. Starting at i = 2 would leave
% aaa(:,:,1) all zero: forward2(1) would always be false, the backward pass
% would compute 0/0 = NaN at i = 1, and Y(1) would never be used.
for i = 1:img_size
    w = bbb.*p_y_2(:,Y(i) + 1)';
    aaa(:,:,i) = w ./ sum(w);
    % Advance one pixel. States 1 and 3 (current bit 0) feed next states 1,2;
    % states 2 and 4 (current bit 1) feed next states 3,4. The column of
    % p_x_2 selects the value of the new bit.
    bbb = [ aaa(1,1:2:3,i)*p_x_2(1:2:3,:)   aaa(1,2:2:4,i)*p_x_2(2:2:4,:) ];
end
% States 2 and 4 are the ones whose current bit is 1.
forward2 = aaa(1,2,:)+aaa(1,4,:) > 0.5;
%----2nd iteration backwards-----%
% Same backward sweep as for the single-bit case, now over the four
% (previous bit, current bit) states. The two states at position i that share
% the same current bit are updated together from the two states at i+1 that
% have that bit as their previous bit.
ccc = aaa;
for i = img_size-1:-1:1
    for g = 1:2
        here  = [g, g+2];        % states at i whose current bit is g-1
        ahead = [2*g-1, 2*g];    % states at i+1 whose previous bit is g-1
        w0 = aaa(:,here,i).*p_x_2(here,1)';   % next bit = 0
        w1 = aaa(:,here,i).*p_x_2(here,2)';   % next bit = 1
        ccc(:,here,i) = ccc(:,ahead,i+1)*[w0/sum(w0); w1/sum(w1)];
    end
end
% Add up the two states whose current bit is 1 and threshold at one half.
backward2 = sum(ccc(1,2:2:4,:), 2) > 0.5;
%-----3rd iteration forward----%
% State k = 4*x(i-2) + 2*x(i-1) + x(i) + 1 packs the last three bits, which
% is the row order used by p_y_3 and p_x_3.
aaa = zeros(1,8,img_size);
bbb = p_3x;   % starting weights over the eight states
% The recursion starts at i = 1, treating the two bits before the first pixel
% as virtual predecessors weighted by p_3x. Starting at i = 3 would leave
% aaa(:,:,1:2) all zero: the backward pass would then compute 0/0 = NaN
% there, the first two pixels would always be decoded as 0, and Y(1:2) would
% never be used.
for i = 1:img_size
    w = bbb.*p_y_3(:,Y(i) + 1)';
    aaa(:,:,i) = w ./ sum(w);
    % Advance one pixel. States g and g+4 share their two newest bits and
    % together feed next states 2g-1 and 2g; the column of p_x_3 selects the
    % value of the new bit.
    bbb = [ aaa(1,1:4:5,i)*p_x_3(1:4:5,:) aaa(1,2:4:6,i)*p_x_3(2:4:6,:) aaa(1,3:4:7,i)*p_x_3(3:4:7,:) aaa(1,4:4:8,i)*p_x_3(4:4:8,:) ];
end
% Even-numbered states are the ones whose current bit is 1.
forward3 = sum(aaa(1,2:2:8,:), 2) > 0.5;
%----3rd iteration backwards-----%
% Backward sweep over the eight three-bit states. The two states at position
% i that share their two newest bits are updated together from the two states
% at i+1 that have those bits as their two oldest bits.
ccc = aaa;
for i = img_size-1:-1:1
    for g = 1:4
        here  = [g, g+4];        % states at i that differ only in the oldest bit
        ahead = [2*g-1, 2*g];    % states at i+1 that differ only in the newest bit
        w0 = aaa(:,here,i).*p_x_3(here,1)';   % next bit = 0
        w1 = aaa(:,here,i).*p_x_3(here,2)';   % next bit = 1
        ccc(:,here,i) = ccc(:,ahead,i+1)*[w0/sum(w0); w1/sum(w1)];
    end
end
% Add up the four states whose current bit is 1 and threshold at one half.
final = sum(ccc(1,2:2:8,:), 2) > 0.5;
%--------Image Output-----------%
figure
% final holds one decision per pixel in the row-by-row order produced when the
% image was flattened, so reshape to width-by-height and transpose back.
imshow(reshape(final,size(pic,2),size(pic,1))');
% Fraction of pixels that differ from the clean bits B, for the noisy input
% and for the decoded result. The mismatch masks are not stored in a variable
% called 'error', which would shadow MATLAB's builtin error() function.
% The trailing semicolons are omitted on purpose so both rates are printed.
orig_error_rate = nnz(xor(Y(:),B))/img_size
post_error_rate = nnz(xor(final(:),B))/img_size